// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package flight

import (
	"context"
	"encoding/base64"
	"io"
	"runtime"
	"strings"
	"sync/atomic"

	"golang.org/x/xerrors"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
)

// Client is an interface wrapped around the FlightServiceClient that is generated
// from the grpc protobuf definitions. It hides the authentication handshake: call
// Authenticate (which drives the ClientAuthHandler) or AuthenticateBasicToken instead
// of manually implementing the grpc communication and the sending of the auth token.
type Client interface {
	// Authenticate uses the ClientAuthHandler that was used when creating the client
	// in order to use the Handshake endpoints of the service.
	Authenticate(context.Context, ...grpc.CallOption) error
	// AuthenticateBasicToken performs a Handshake that sends username and password as
	// Basic credentials and, on success, returns a context whose outgoing metadata
	// carries the authorization token sent back by the server.
	AuthenticateBasicToken(ctx context.Context, username string, password string, opts ...grpc.CallOption) (context.Context, error)
	// Close closes the client's underlying grpc connection.
	Close() error
	// join the interface from the FlightServiceClient instead of re-defining all
	// the endpoints here.
	FlightServiceClient
}

// CustomClientMiddleware is the minimal interface for client middleware that is
// turned into grpc interceptors by CreateClientMiddleware. An implementation may
// additionally implement ClientPostCallMiddleware and/or ClientHeadersMiddleware to
// receive further callbacks.
type CustomClientMiddleware interface {
	// StartCall is invoked before every unary or streaming call. If it returns a
	// non-nil context, that context is used for the call instead of ctx.
	StartCall(ctx context.Context) context.Context
}

// ClientPostCallMiddleware is an optional extension of CustomClientMiddleware that
// is notified when a call completes.
type ClientPostCallMiddleware interface {
	// CallCompleted is invoked once a call finishes, with the call's final error
	// (nil on success).
	CallCompleted(ctx context.Context, err error)
}

// ClientHeadersMiddleware is an optional extension of CustomClientMiddleware that
// receives the response metadata of each call.
type ClientHeadersMiddleware interface {
	// HeadersReceived is invoked when a call finishes, with the response header
	// metadata joined with the trailer metadata.
	HeadersReceived(ctx context.Context, md metadata.MD)
}

// CreateClientMiddleware adapts middleware into a ClientMiddleware whose unary and
// stream interceptors can be passed to NewClientWithMiddleware.
//
// For every call StartCall is invoked first; if it returns a non-nil context, that
// context is used for the call. If middleware also implements
// ClientPostCallMiddleware, CallCompleted is invoked once when the call finishes
// (for streams, io.EOF from RecvMsg is reported as a nil error). If it implements
// ClientHeadersMiddleware, HeadersReceived is invoked when the call finishes with
// the response headers joined with the trailers.
func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware {
	return ClientMiddleware{
		Unary: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
			if nctx := middleware.StartCall(ctx); nctx != nil {
				ctx = nctx
			}

			if hdrs, ok := middleware.(ClientHeadersMiddleware); ok {
				hdrmd := make(metadata.MD)
				trailermd := make(metadata.MD)
				// Cap opts at its length so append always copies. The variadic slice
				// is not guaranteed to be owned by this call (it may alias call
				// options shared across calls), and appending in place could write
				// into spare capacity that concurrent calls also read.
				opts = append(opts[:len(opts):len(opts)], grpc.Header(&hdrmd), grpc.Trailer(&trailermd))
				// Deferred, so it runs after CallCompleted below.
				defer func() {
					hdrs.HeadersReceived(ctx, metadata.Join(hdrmd, trailermd))
				}()
			}

			err := invoker(ctx, method, req, reply, cc, opts...)
			if post, ok := middleware.(ClientPostCallMiddleware); ok {
				post.CallCompleted(ctx, err)
			}
			return err
		},
		Stream: func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
			if nctx := middleware.StartCall(ctx); nctx != nil {
				ctx = nctx
			}

			cs, err := streamer(ctx, desc, cc, method, opts...)
			hdrs, isHdrs := middleware.(ClientHeadersMiddleware)
			post, isPostcall := middleware.(ClientPostCallMiddleware)
			if !isPostcall && !isHdrs {
				return cs, err
			}

			if err != nil {
				if isHdrs {
					// The streamer may hand back a nil ClientStream together with an
					// error; calling Header/Trailer on it would panic, so report
					// empty metadata in that case.
					md := metadata.MD{}
					if cs != nil {
						hdrmd, _ := cs.Header()
						md = metadata.Join(hdrmd, cs.Trailer())
					}
					hdrs.HeadersReceived(ctx, md)
				}
				if isPostcall {
					post.CallCompleted(ctx, err)
				}
				return cs, err
			}

			// Grab the client stream context because when the finish function or the goroutine below will be
			// executed it's not guaranteed cs.Context() will be valid.
			csCtx := cs.Context()
			finishChan := make(chan struct{})
			var finished int32
			finishFunc := func(err error) {
				// Several paths (the stream wrapper methods, the context watcher and
				// the finalizer below) can call finishFunc; only the first call acts.
				if !atomic.CompareAndSwapInt32(&finished, 0, 1) {
					return
				}

				close(finishChan)
				if isPostcall {
					post.CallCompleted(csCtx, err)
				}
				if isHdrs {
					hdrmd, _ := cs.Header()
					hdrs.HeadersReceived(csCtx, metadata.Join(hdrmd, cs.Trailer()))
				}
			}
			// Finish the call if the stream's context ends before any wrapper
			// method observed completion; exits once finishFunc has run.
			go func() {
				select {
				case <-finishChan:
					// finish is being called by something else, no action necessary
				case <-csCtx.Done():
					finishFunc(csCtx.Err())
				}
			}()

			newCS := &clientStream{
				ClientStream: cs,
				desc:         desc,
				finishFn:     finishFunc,
			}
			// The `ClientStream` interface allows one to omit calling `Recv` if it's
			// known that the result will be `io.EOF`. See
			// http://stackoverflow.com/q/42915337
			// In such cases nothing triggers the middleware callbacks. We therefore
			// set a finalizer so that the callbacks run and the context goroutine
			// exits at least when the garbage collector collects the wrapper.
			runtime.SetFinalizer(newCS, func(newcs *clientStream) {
				newcs.finishFn(nil)
			})
			return newCS, nil
		},
	}
}

// clientStream wraps a grpc.ClientStream so that finishFn is called when the stream
// ends: on an error from Header, SendMsg, RecvMsg or CloseSend, on io.EOF from
// RecvMsg (reported as nil), or after the first received message when
// desc.ServerStreams is false. The finishFn built by CreateClientMiddleware ignores
// every call after the first, so repeated triggers are harmless.
type clientStream struct {
	grpc.ClientStream
	desc     *grpc.StreamDesc
	finishFn func(error)
}

// Header returns the wrapped stream's header metadata. A failure to obtain the
// headers ends the call, so the error is reported through finishFn as well.
func (cs *clientStream) Header() (metadata.MD, error) {
	hdr, hdrErr := cs.ClientStream.Header()
	if hdrErr == nil {
		return hdr, nil
	}
	cs.finishFn(hdrErr)
	return hdr, hdrErr
}

// SendMsg forwards m to the wrapped stream; a send failure ends the call and is
// reported through finishFn before being returned.
func (cs *clientStream) SendMsg(m interface{}) error {
	if sendErr := cs.ClientStream.SendMsg(m); sendErr != nil {
		cs.finishFn(sendErr)
		return sendErr
	}
	return nil
}

// RecvMsg reads the next message into m and signals completion via finishFn when
// the stream is over: io.EOF counts as success (nil), any other error is passed on,
// and for RPCs without server streaming the first successful read completes the
// call. The error from the wrapped stream is always returned unchanged.
func (cs *clientStream) RecvMsg(m interface{}) error {
	recvErr := cs.ClientStream.RecvMsg(m)
	switch {
	case recvErr == io.EOF:
		cs.finishFn(nil)
	case recvErr != nil:
		cs.finishFn(recvErr)
	case !cs.desc.ServerStreams:
		// Only a single response is expected, so the call is complete.
		cs.finishFn(nil)
	}
	return recvErr
}

// CloseSend half-closes the wrapped stream; if that fails the call is finished
// with the error, which is also returned.
func (cs *clientStream) CloseSend() error {
	if closeErr := cs.ClientStream.CloseSend(); closeErr != nil {
		cs.finishFn(closeErr)
		return closeErr
	}
	return nil
}

// ClientMiddleware is a pair of grpc client interceptors to install on a Flight
// client through NewClientWithMiddleware. Either field may be nil, in which case
// NewClientWithMiddleware skips it.
type ClientMiddleware struct {
	Stream grpc.StreamClientInterceptor
	Unary  grpc.UnaryClientInterceptor
}

// client is the Client implementation returned by NewFlightClient and
// NewClientWithMiddleware. The generated FlightServiceClient is embedded to provide
// the RPC methods; conn is closed by Close, and authHandler (which may be nil) is
// used by Authenticate.
type client struct {
	conn        *grpc.ClientConn
	authHandler ClientAuthHandler

	FlightServiceClient
}

// NewFlightClient dials the grpc server at addr and returns a Client for it. When
// auth is non-nil, its stream and unary interceptors are placed ahead of any dial
// options supplied by the caller so the application-level handshake token is sent
// with requests. TLS and any other grpc configuration can still be passed via the
// grpc.DialOption list, exactly as when connecting manually without this helper.
//
// Alternatively, a grpc client can be constructed as normal without this helper as the
// grpc generated client code is still exported. This exists to add utility and helpers
// around the authentication and passing the token with requests.
//
// Deprecated: prefer to use NewClientWithMiddleware
func NewFlightClient(addr string, auth ClientAuthHandler, opts ...grpc.DialOption) (Client, error) {
	dialOpts := opts
	if auth != nil {
		authOpts := []grpc.DialOption{
			grpc.WithChainStreamInterceptor(createClientAuthStreamInterceptor(auth)),
			grpc.WithChainUnaryInterceptor(createClientAuthUnaryInterceptor(auth)),
		}
		dialOpts = append(authOpts, opts...)
	}

	conn, dialErr := grpc.Dial(addr, dialOpts...)
	if dialErr != nil {
		return nil, dialErr
	}

	fc := &client{
		conn:                conn,
		authHandler:         auth,
		FlightServiceClient: NewFlightServiceClient(conn),
	}
	return fc, nil
}

// NewClientWithMiddleware dials the grpc server at addr and returns a Client whose
// calls pass through the given middleware. The interceptors are chained by grpc so
// that the first middleware is the outermost and the last middleware is the
// innermost wrapper around the actual call. When auth is non-nil, its interceptors
// are placed before all middleware in the chain. Nil Unary or Stream fields in a
// ClientMiddleware are skipped. The dial options passed in (TLS certs and so on) are
// forwarded to grpc.Dial.
func NewClientWithMiddleware(addr string, auth ClientAuthHandler, middleware []ClientMiddleware, opts ...grpc.DialOption) (Client, error) {
	// +1 leaves room for the optional auth interceptor.
	unary := make([]grpc.UnaryClientInterceptor, 0, len(middleware)+1)
	stream := make([]grpc.StreamClientInterceptor, 0, len(middleware)+1)
	if auth != nil {
		unary = append(unary, createClientAuthUnaryInterceptor(auth))
		stream = append(stream, createClientAuthStreamInterceptor(auth))
	}
	for _, m := range middleware {
		if m.Unary != nil {
			unary = append(unary, m.Unary)
		}
		if m.Stream != nil {
			stream = append(stream, m.Stream)
		}
	}

	// opts may be the caller's own slice (passed with `opts...`); capping it at its
	// length forces append to copy instead of writing into the caller's spare
	// capacity.
	opts = append(opts[:len(opts):len(opts)], grpc.WithChainUnaryInterceptor(unary...), grpc.WithChainStreamInterceptor(stream...))
	conn, err := grpc.Dial(addr, opts...)
	if err != nil {
		return nil, err
	}

	return &client{conn: conn, FlightServiceClient: NewFlightServiceClient(conn), authHandler: auth}, nil
}

// AuthenticateBasicToken performs a Handshake that carries "username:password" as
// Basic credentials in the "Authorization" request metadata. It returns a copy of
// ctx whose outgoing metadata carries the first non-empty "authorization" value the
// server sent back in its response headers or trailers. On any failure the original
// ctx is returned together with the error.
func (c *client) AuthenticateBasicToken(ctx context.Context, username, password string, opts ...grpc.CallOption) (context.Context, error) {
	// NOTE(review): RawStdEncoding omits '=' padding, whereas Basic credentials are
	// conventionally padded base64. Kept as-is because peers may decode with the same
	// unpadded encoding; confirm interoperability before changing.
	creds := base64.RawStdEncoding.EncodeToString([]byte(strings.Join([]string{username, password}, ":")))
	authCtx := metadata.AppendToOutgoingContext(ctx, "Authorization", "Basic "+creds)

	stream, err := c.FlightServiceClient.Handshake(authCtx, opts...)
	if err != nil {
		return ctx, err
	}

	header, err := stream.Header()
	if err != nil {
		return ctx, err
	}

	_, recvErr := stream.Recv()
	if recvErr != nil && recvErr != io.EOF {
		return ctx, recvErr
	}

	if err := stream.CloseSend(); err != nil {
		return ctx, err
	}

	// Trailers are only available, and the stream only released, once Recv has
	// returned a non-nil error. If the server sent a message, keep reading until
	// the stream finishes.
	for recvErr == nil {
		_, recvErr = stream.Recv()
	}
	if recvErr != io.EOF {
		return ctx, recvErr
	}

	md := metadata.Join(header, stream.Trailer())
	for _, token := range md.Get("authorization") {
		if token != "" {
			return metadata.AppendToOutgoingContext(ctx, "Authorization", token), nil
		}
	}

	return ctx, xerrors.Errorf("flight: no authorization header on the response")
}

// Authenticate opens a Handshake stream and lets the client's ClientAuthHandler
// drive the exchange over it. It returns a codes.NotFound status error if the
// client was created without an auth handler, and otherwise the error from opening
// the stream or from the handler.
func (c *client) Authenticate(ctx context.Context, opts ...grpc.CallOption) error {
	handler := c.authHandler
	if handler == nil {
		return status.Error(codes.NotFound, "cannot authenticate without an auth-handler")
	}

	hs, openErr := c.FlightServiceClient.Handshake(ctx, opts...)
	if openErr != nil {
		return openErr
	}
	return handler.Authenticate(ctx, &clientAuthConn{hs})
}

// Close clears the embedded FlightServiceClient and then closes the underlying grpc
// connection, returning any error from closing it.
func (c *client) Close() error {
	conn := c.conn
	c.FlightServiceClient = nil
	return conn.Close()
}
